/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/conv_op.h"

#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/infermeta/binary.h"

namespace paddle {
namespace operators {

std::vector<int64_t> ConvOp::ComputeOutputShape(
    framework::InferShapeContext* ctx) const {
  OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Conv");
  OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter", "Conv");

  auto in_dims = ctx->GetInputDim("Input");
  auto filter_dims = ctx->GetInputDim("Filter");

  std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
  std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
  std::string padding_algorithm =
      ctx->Attrs().Get<std::string>("padding_algorithm");
  int groups = ctx->Attrs().Get<int>("groups");
  std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations");
  int dilation_size = dilations.size();
  for (int i = 0; i < dilation_size; ++i) {
    PADDLE_ENFORCE_GT(
        dilations[i],
        0,
        platform::errors::InvalidArgument(
            "The dilation of Op(Conv) should be larget than 0, but received "
            "dilation is %d.",
            dilations[i]));
  }
  const std::string data_format = ctx->Attrs().Get<std::string>("data_format");

  // MKL-DNN Kernels are using NCHW order of dims description
  // so we ignore data_format consideration for MKL-DNN kernel
  const bool channel_last = (ctx->IsRunMKLDNNKernel() == false) &&
                            (data_format == "NHWC" || data_format == "NDHWC");

  PADDLE_ENFORCE_EQ(
      in_dims.size() == 4 || in_dims.size() == 5,
      true,
      platform::errors::InvalidArgument(
          "The input of Op(Conv) should be a 4-D or 5-D Tensor. But "
          "received: input's dimension is %u, input's shape is [%s].",
          in_dims.size(),
          in_dims));

  PADDLE_ENFORCE_EQ(
      in_dims.size(),
      filter_dims.size(),
      platform::errors::InvalidArgument(
          "The input's dimension and filter's dimension of "
          "Op(Conv) should be equal. But received: the input's shape is [%s], "
          "the input's dimension is %d; the filter's shape is [%s],  "
          "the filter's dimension is %d.",
          in_dims,
          in_dims.size(),
          filter_dims,
          filter_dims.size()));

  int stride_size = strides.size();
  for (int i = 0; i < stride_size; ++i) {
    PADDLE_ENFORCE_GT(
        strides[i],
        0,
        platform::errors::InvalidArgument(
            "The stride of Op(Conv) should be larget than 0, but received "
            "stride is %d.",
            strides[i]));
  }

  int in_sub_stride_size = in_dims.size() - stride_size;
  PADDLE_ENFORCE_EQ(
      in_dims.size(),
      strides.size() + 2U,
      platform::errors::InvalidArgument(
          "The difference of input's dimension and Attr(strides)'s "
          "length must be euqal to 2 for Op(Conv). "
          "But received: input's dimension is %d, input's shape is [%s]; "
          "Attr(stride)'s length is %d, Attr(stride) is [%s]; "
          "difference of input's dimention and Attr(strides)'s length = %u.",
          in_dims.size(),
          in_dims,
          strides.size(),
          phi::make_ddim(strides),
          in_sub_stride_size));

  const auto input_channels =
      channel_last ? in_dims[in_dims.size() - 1] : in_dims[1];

  PADDLE_ENFORCE_EQ(
      input_channels,
      filter_dims[1] * groups,
      platform::errors::InvalidArgument(
          "The number of input's channels should be equal to filter's channels "
          "* groups for Op(Conv). But received: the input's channels is %d, "
          "the input's shape is [%s]; the filter's channels is %d, the "
          "filter's shape is [%s]; the groups is %d, the data_format is %s. "
          "The error may come from wrong data_format setting.",
          input_channels,
          in_dims,
          filter_dims[1],
          filter_dims,
          groups,
          data_format));
  PADDLE_ENFORCE_EQ(
      filter_dims[0] % groups,
      0,
      platform::errors::InvalidArgument(
          "The number of output's channels (filter's first dimension) of "
          "Op(Conv) should be divided by groups. But received: "
          "the output channels is %d, the filter's shape is [%s], "
          "the groups is %d.",
          filter_dims[0],
          filter_dims,
          groups));

  if (ctx->IsRuntime()) {
    PADDLE_ENFORCE_GT(
        filter_dims[0],
        0,
        platform::errors::InvalidArgument(
            "the size of filter at axis 0 should be greater than 0"));
  }

  framework::DDim in_data_dims;
  if (channel_last) {
    in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
  } else {
    in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
  }

  framework::DDim filter_data_dims =
      phi::slice_ddim(filter_dims, 2, filter_dims.size());

  std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
  UpdatePaddingAndDilation(
      &paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize);

  std::vector<int64_t> output_shape({in_dims[0]});
  if (!channel_last) {
    output_shape.push_back(filter_dims[0]);
  }
  for (int i = 0; i < in_data_dims.size(); ++i) {
    if ((!ctx->IsRuntime()) &&
        (in_data_dims[i] <= 0 || filter_dims[i + 2] <= 0)) {
      output_shape.push_back(-1);
    } else {
      output_shape.push_back(ConvOutputSize(in_data_dims[i],
                                            filter_data_dims[i],
                                            dilations[i],
                                            paddings[2 * i],
                                            paddings[2 * i + 1],
                                            strides[i]));
    }
  }
  if (channel_last) {
    output_shape.push_back(filter_dims[0]);
  }

  return output_shape;
}

phi::KernelKey ConvOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
  // todo enable data layout when it's ready
  // (https://github.com/PaddlePaddle/Paddle/pull/20042)

  if (input_data_type != framework::proto::VarType::INT8 &&
      input_data_type != framework::proto::VarType::UINT8 &&
      input_data_type != framework::proto::VarType::BF16) {
    auto filter_data_type = framework::TransToProtoVarType(
        ctx.Input<phi::DenseTensor>("Filter")->dtype());
    PADDLE_ENFORCE_EQ(
        input_data_type,
        filter_data_type,
        platform::errors::InvalidArgument(
            "input and filter data type should be consistent, "
            "but received input data type is %s and filter type "
            "is %s",
            paddle::framework::DataTypeToString(input_data_type),
            paddle::framework::DataTypeToString(filter_data_type)));
  }

  return phi::KernelKey(input_data_type, ctx.GetPlace());
}

phi::KernelKey ConvOp::GetKernelTypeForVar(
    const std::string& var_name,
    const phi::DenseTensor& tensor,
    const phi::KernelKey& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
  // Only input require reshaping, weights and
  // bias are having shape in NCHW order
  if ((var_name == "Input") &&
      (expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
      (tensor.layout() != phi::DataLayout::ONEDNN)) {
    auto attrs = Attrs();
    auto ar = paddle::framework::AttrReader(attrs);
    const std::string data_format = ar.Get<std::string>("data_format");
    auto dl = phi::StringToDataLayout(data_format);
    // Some models may have intentionally set "AnyLayout" for conv
    // op. Treat this as NCHW (default data_format value)
    if (dl != phi::DataLayout::kAnyLayout) {
      return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype());
    }
  }
#endif
  return phi::KernelKey(
      tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}

void Conv2DOpMaker::Make() {
  AddInput("Input",
           "(Tensor) The input tensor of convolution operator. "
           "The format of input tensor is NCHW or NHWC, where N is batch size, "
           "C is the "
           "number of channels, H is the height of the feature, "
           "and W is the width of the feature.");
  AddInput("Filter",
           "(Tensor) The filter tensor of convolution operator. "
           "The format of the filter tensor is MCHW, where M is the number of "
           "output image channels, C is the number of input image channels, "
           "H is the height of the filter, and W is the width of the filter. "
           "If the groups attribute is greater than 1, C equals the number of "
           "input image channels divided by the groups.");
  AddOutput("Output",
            "(Tensor) The output tensor of convolution operator. "
            "It has same data fromat and data type as the Input.");
  AddAttr<std::vector<int>>("strides",
                            "(vector<int> default:{1, 1}), the "
                            "strides(h_stride, w_stride) of "
                            "convolution operator.")
      .SetDefault({1, 1});
  AddAttr<std::vector<int>>("paddings",
                            "(vector<int> default:{0, 0}), the "
                            "paddings(pad_height_top, pad_height_bottom, "
                            "pad_width_left, pad_wifth_right)  of "
                            "convolution operator.")
      .SetDefault({0, 0});
  AddAttr<std::string>(
      "padding_algorithm",
      "(string, default \"EXPLICIT\") An optional string from: \"EXPLICIT\","
      "\"SAME\",\"VALID\". Set to \"EXPLICIT\" for explicit padding. "
      "Set to \"SAME\" or \"VALID\" for algorithm of padding. ")
      .SetDefault("EXPLICIT");
  AddAttr<int>(
      "groups",
      "(int default:1), the groups number of the convolution operator. "
      "According to grouped convolution in Alex Krizhevsky's Deep CNN paper: "
      "when group=2, the first half of the filters is only connected to the "
      "first half of the input channels, while the second half of the filters "
      "is only connected to the second half of the input channels.")
      .SetDefault(1);
  AddAttr<std::vector<int>>("dilations",
                            "(vector<int> default:{1, 1}), the "
                            "dilations(h_dilation, w_dilation) of "
                            "convolution operator.")
      .SetDefault({1, 1});
  AddAttr<std::string>(
      "data_format",
      "(string, default NCHW) Only used in "
      "An optional string from: \"NHWC\", \"NCHW\". "
      "Defaults to \"NHWC\". Specify the data format of the output data, "
      "the input will be transformed automatically. ")
      .SetDefault("NCHW");
  // TODO(dzhwinter): need to registered layout transform function
  AddComment(R"DOC(
Convolution Operator.

The convolution operation calculates the output based on the input, filter
and strides, paddings, dilations, groups parameters. The size of each dimension of the
parameters is checked in the infer-shape.
Input(Input) and Output(Output) are in NCHW or NHWC format. Where N is batch
size, C is the number of channels, H is the height of the feature, and W is
the width of the feature.
Filters(Input) is MCHW format format. Where M is the number of output image channels, C is
the number of input image channels, H is the height of the filter, and W
is the width of the filter.
Parameters(strides, paddings, dilations) are two elements. These two elements represent
height and width, respectively.
The input(X) size and output(Out) size may be different.

Example:
  Input:
       Input shape: $(N, C_{in}, H_{in}, W_{in})$
       Filter shape: $(C_{out}, C_{in}, H_f, W_f)$
  Output:
       Output shape: $(N, C_{out}, H_{out}, W_{out})$
  Where
$$
       H_{out}= \frac{(H_{in} + pad_height_top + pad_height_bottom - (dilations[0] * (H_f - 1) + 1))}{strides[0]}+ 1 \\
       W_{out}= \frac{(W_{in} + pad_width_left + pad_width_right - (dilations[1] * (W_f - 1) + 1))}{strides[1]}+ 1
$$
)DOC");
  Apply();
}

class DepthwiseConv2DOpMaker : public Conv2DOpMaker {
 protected:
  void Apply() override {
    AddAttr<bool>(
        "use_cudnn",
        "(bool, default false) Only used in cudnn kernel, need install cudnn")
        .SetDefault(false)
        .AsExtra();
  }
};

void Conv3DOpMaker::Make() {
  AddInput(
      "Input",
      "(Tensor) The input tensor of convolution operator. "
      "The format of input tensor is NCDHW or NDHWC. Where N is batch size, C "
      "is the "
      "number of channels, D is the depth of the feature, H is the height of "
      "the feature, "
      "and W is the width of the feature.");
  AddInput("Filter",
           "(Tensor) The filter tensor of convolution operator. "
           "The format of the filter tensor is MCDHW, where M is the number of "
           "output image channels, C is the number of input image channels, "
           "D is the depth of the filter, H is the height of the filter, and W "
           "is the width of the filter."
           "If the groups attribute is greater than 1, C equals the number of "
           "input image channels divided by the groups.");
  AddOutput("Output",
            "(Tensor) The output tensor of convolution operator."
            "It has same data fromat and data type as the Input.");
  AddAttr<std::vector<int>>("strides",
                            "(vector<int>, default:{1, 1, 1}), the "
                            "strides(d_stride, h_stride, w_stride) of "
                            "convolution operator.")
      .SetDefault({1, 1, 1});
  AddAttr<std::vector<int>>(
      "paddings",
      "(vector<int>, default:{0, 0, 0}), the "
      "paddings(pad_depth_front, pad_depth_back, pad_height_top, "
      "pad_height_bottom, pad_width_left, pad_width_right) of convolution "
      "operator.")
      .SetDefault({0, 0, 0});
  AddAttr<std::string>(
      "padding_algorithm",
      "(string, default \"EXPLICIT\") An optional string from: \"EXPLICIT\","
      "\"SAME\",\"VALID\". Set to \"EXPLICIT\" for explicit padding. "
      "Set to \"SAME\" or \"VALID\" for algorithm of padding. ")
      .SetDefault("EXPLICIT");
  AddAttr<int>(
      "groups",
      "(int default:1), the groups number of the convolution operator. "
      "According to grouped convolution in Alex Krizhevsky's Deep CNN paper: "
      "when group=2, the first half of the filters is only connected to the "
      "first half of the input channels, while the second half of the filters "
      "is only connected to the second half of the input channels.")
      .SetDefault(1);
  AddAttr<std::vector<int>>("dilations",
                            "(vector<int> default:{1, 1, 1}), the "
                            "dilations(d_dilation, h_dilation, w_dilation) of "
                            "convolution operator.")
      .SetDefault({1, 1, 1});
  AddAttr<std::string>(
      "data_format",
      "(string, default NCDHW) Only used in "
      "An optional string from: \"NDHWC\", \"NCDHW\". "
      "Defaults to \"NDHWC\". Specify the data format of the output data, "
      "the input will be transformed automatically. ")
      .SetDefault("NCDHW");
  AddComment(R"DOC(
Convolution3D Operator.

The convolution operation calculates the output based on the input, filter
and strides, paddings, dilations, groups parameters. The size of each dimension of the
parameters is checked in the infer-shape.
Input(Input) and output(Output) are in NCDHW or NDHWC format, where N is batch
size, C is the number of channels,D is the depth of the feature, H is the height of
the feature, and W is the width of the feature.
Filters(Input) is MCDHW format, where M is the number of output image channels,
C is the number of input image channels, D is the depth of the filter,
H is the height of the filter, and W is the width of the filter.
Parameters(strides, paddings, dilations) are three elements. These three elements
represent depth, height and width, respectively.
The input(X) size and output(Out) size may be different.

Example:
  Input:
       Input shape: $(N, C_{in}, D_{in}, H_{in}, W_{in})$
       Filter shape: $(C_{out}, C_{in}, D_f, H_f, W_f)$
  Output:
       Output shape: $(N, C_{out}, D_{out}, H_{out}, W_{out})$
  Where
  $$
       D_{out}= \frac{(D_{in} + pad_depth_front + pad_depth_back - (dilations[0] * (D_f - 1) + 1))}{ strides[0]}+ 1 \\
       H_{out}= \frac{(H_{in} + pad_height_top + pad_height_bottom - (dilations[1] * (H_f - 1) + 1))}{ strides[1]}+ 1 \\
       W_{out}= \frac{(W_{in} + pad_width_left + pad_width_right - (dilations[2] * (W_f - 1) + 1))}{ strides[2]}+ 1
  $$
)DOC");
  Apply();
}

void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const {
  auto in_dims = ctx->GetInputDim("Input");
  auto filter_dims = ctx->GetInputDim("Filter");
  if (ctx->HasOutput(framework::GradVarName("Input"))) {
    ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
  }
  if (ctx->HasOutput(framework::GradVarName("Filter"))) {
    ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
  }
}

phi::KernelKey ConvOpGrad::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
  auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
  return phi::KernelKey(data_type, ctx.GetPlace());
}

phi::KernelKey ConvOpGrad::GetKernelTypeForVar(
    const std::string& var_name,
    const phi::DenseTensor& tensor,
    const phi::KernelKey& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
  // Only input require reshaping, weights and
  // bias are having shape in NCHW order
  if (((var_name == "Input") ||
       (var_name == framework::GradVarName("Output"))) &&
      (expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
      (tensor.layout() != phi::DataLayout::ONEDNN)) {
    auto attrs = Attrs();
    auto ar = paddle::framework::AttrReader(attrs);
    const std::string data_format = ar.Get<std::string>("data_format");
    auto dl = phi::StringToDataLayout(data_format);
    // Some models may have intentionally set "AnyLayout" for pool
    // op. Treat this as NCHW (default data_format value)
    if (dl != phi::DataLayout::kAnyLayout) {
      return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype());
    }
  }
#endif
  return phi::KernelKey(
      tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}

template <typename T>
class Conv2DGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> op) const override {
    op->SetType(this->ForwardOpType() + "_grad");
    op->SetInput("Input", this->Input("Input"));
    op->SetInput("Filter", this->Input("Filter"));
    op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));

    op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
    op->SetOutput(framework::GradVarName("Filter"), this->InputGrad("Filter"));

    if (this->HasInput("Bias")) {
      op->SetInput("Bias", this->Input("Bias"));
      op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
    }
    op->SetAttrMap(this->Attrs());
  }
};

template <typename T>
class Conv3DGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> op) const override {
    op->SetType(this->ForwardOpType() + "_grad");
    op->SetInput("Input", this->Input("Input"));
    op->SetInput("Filter", this->Input("Filter"));
    op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));

    op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
    op->SetOutput(framework::GradVarName("Filter"), this->InputGrad("Filter"));

    if (this->HasInput("ResidualData")) {
      op->SetInput("ResidualData", this->Input("ResidualData"));
    }

    op->SetAttrMap(this->Attrs());
  }
};

/*
 * Inputs:  I, W, dO, ddI, ddW
 * Outputs: ddO, dW, dI
 */
template <typename T>
class Conv2DDoubleGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> op) const override {
    op->SetType(this->ForwardOpType() + "_grad");
    // I, W, dO, ddI, ddW
    op->SetInput("Input", this->Input("Input"));
    op->SetInput("Filter", this->Input("Filter"));
    op->SetInput("DOutput", this->Input(framework::GradVarName("Output")));
    op->SetInput("DDInput", this->OutputGrad(framework::GradVarName("Input")));
    op->SetInput("DDFilter",
                 this->OutputGrad(framework::GradVarName("Filter")));

    // ddO, dI, dW
    // Unlike grad op, double grad op does not use name@GRAD@GRAD
    // as key of ops' inputs and outputs.
    auto ddx = this->OutputGrad(framework::GradVarName("Input"));
    auto ddw = this->OutputGrad(framework::GradVarName("Filter"));

    op->SetOutput("DDOutput",
                  ddx.empty()
                      ? this->EmptyInputGrad()
                      : this->InputGrad(framework::GradVarName("Output")));
    op->SetOutput(
        "DFilter",
        ddx.empty() ? this->EmptyInputGrad() : this->InputGrad("Filter"));
    op->SetOutput(
        "DInput",
        ddw.empty() ? this->EmptyInputGrad() : this->InputGrad("Input"));

    op->SetAttrMap(this->Attrs());
  }
};

/*
 * Inputs:  I, W, dO, ddI, ddW
 * Outputs: ddO, dW, dI
 */
template <typename T>
class Conv3DDoubleGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> op) const override {
    op->SetType(this->ForwardOpType() + "_grad");
    // I, W, dO, ddI, ddW
    op->SetInput("Input", this->Input("Input"));
    op->SetInput("Filter", this->Input("Filter"));
    op->SetInput("DOutput", this->Input(framework::GradVarName("Output")));
    op->SetInput("DDInput", this->OutputGrad(framework::GradVarName("Input")));
    op->SetInput("DDFilter",
                 this->OutputGrad(framework::GradVarName("Filter")));

    auto ddx = this->OutputGrad(framework::GradVarName("Input"));
    auto ddw = this->OutputGrad(framework::GradVarName("Filter"));

    op->SetOutput("DDOutput",
                  ddx.empty()
                      ? this->EmptyInputGrad()
                      : this->InputGrad(framework::GradVarName("Output")));
    op->SetOutput(
        "DFilter",
        ddx.empty() ? this->EmptyInputGrad() : this->InputGrad("Filter"));
    op->SetOutput(
        "DInput",
        ddw.empty() ? this->EmptyInputGrad() : this->InputGrad("Input"));

    op->SetAttrMap(this->Attrs());
  }
};

void ConvOpDoubleGrad::InferShape(framework::InferShapeContext* ctx) const {
  auto x_dims = ctx->GetInputDim("Input");
  auto w_dims = ctx->GetInputDim("Filter");
  auto do_dims = ctx->GetInputDim("DOutput");

  if (ctx->HasOutput("DDOutput") &&
      (ctx->HasInput("DDInput") || (ctx->HasInput("DDFilter")))) {
    ctx->SetOutputDim("DDOutput", do_dims);
  }
  if (ctx->HasOutput("DFilter") && ctx->HasInput("DDInput")) {
    ctx->SetOutputDim("DFilter", w_dims);
  }
  if (ctx->HasOutput("DInput") && ctx->HasInput("DDFilter")) {
    ctx->SetOutputDim("DInput", x_dims);
  }
}

phi::KernelKey ConvOpDoubleGrad::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
  return phi::KernelKey(data_type, ctx.GetPlace());
}

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(conv2d,
                  ops::ConvOp,
                  ops::Conv2DOpMaker,
                  ops::ConvOpInferVarType,
                  ops::Conv2DGradMaker<paddle::framework::OpDesc>,
                  ops::Conv2DGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(conv2d_grad,
                  ops::ConvOpGrad,
                  ops::Conv2DDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::Conv2DDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(conv2d_grad_grad, ops::ConvOpDoubleGrad);

// depthwise convolution op
REGISTER_OPERATOR(depthwise_conv2d,
                  ops::ConvOp,
                  ops::DepthwiseConv2DOpMaker,
                  ops::ConvOpInferVarType,
                  ops::Conv2DGradMaker<paddle::framework::OpDesc>,
                  ops::Conv2DGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(depthwise_conv2d_grad,
                  ops::ConvOpGrad,
                  ops::Conv2DDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::Conv2DDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(depthwise_conv2d_grad_grad, ops::ConvOpDoubleGrad);

REGISTER_OPERATOR(conv3d,
                  ops::ConvOp,
                  ops::Conv3DOpMaker,
                  ops::ConvOpInferVarType,
                  ops::Conv3DGradMaker<paddle::framework::OpDesc>,
                  ops::Conv3DGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(conv3d_grad,
                  ops::ConvOpGrad,
                  ops::Conv3DDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::Conv3DDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(conv3d_grad_grad, ops::ConvOpDoubleGrad);

REGISTER_OP_VERSION(conv2d).AddCheckpoint(
    R"ROC(
      Upgrade conv2d, add a new attribute [use_addto].
    )ROC",
    paddle::framework::compatible::OpVersionDesc().NewAttr(
        "use_addto",
        "In order to support new feature (inplace addto strategy) for "
        "gradient accumulation.",
        false));

REGISTER_OP_VERSION(depthwise_conv2d)
    .AddCheckpoint(
        R"ROC(
      Upgrade depthwise_conv2d, add a new attribute [use_addto].
    )ROC",
        paddle::framework::compatible::OpVersionDesc().NewAttr(
            "use_addto",
            "In order to support new feature (inplace addto strategy) for "
            "gradient accumulation.",
            false));

REGISTER_OP_VERSION(conv3d).AddCheckpoint(
    R"ROC(
      Upgrade conv3d, add a new attribute [use_addto].
    )ROC",
    paddle::framework::compatible::OpVersionDesc().NewAttr(
        "use_addto",
        "In order to support new feature (inplace addto strategy) for "
        "gradient accumulation.",
        false));
